Add DecomposeLstmPass for ARM backend (#17140) #17140

Merged
digantdesai merged 3 commits into pytorch:main from apullin:export-D92059277 on Apr 15, 2026

Conversation

@apullin

@apullin apullin commented Feb 3, 2026

Contributor

Summary:

Adds a decomposition pass that transforms aten.lstm.input into elementary
ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat).

LSTM cell equations per timestep:
i_t = sigmoid(x_t @ W_ii.T + b_ii + h_{t-1} @ W_hi.T + b_hi)
f_t = sigmoid(x_t @ W_if.T + b_if + h_{t-1} @ W_hf.T + b_hf)
g_t = tanh(x_t @ W_ig.T + b_ig + h_{t-1} @ W_hg.T + b_hg)
o_t = sigmoid(x_t @ W_io.T + b_io + h_{t-1} @ W_ho.T + b_ho)
c_t = f_t * c_{t-1} + i_t * g_t
h_t = o_t * tanh(c_t)

Features:

  • Multi-layer LSTM support
  • Bidirectional LSTM support
  • With/without bias
  • batch_first support
  • Batched gate computation (2 mm ops per timestep instead of 8)

Differential Revision: D92059277
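
To illustrate the "2 mm ops per timestep instead of 8" point, here is a minimal sketch of one LSTM timestep in plain Python. It is not the pass's actual graph-rewriting code. The function names (`lstm_step_batched`, `lstm_step_per_gate`, `matvec`) are hypothetical, and the stacked weight layout (gate blocks in the order i, f, g, o, as PyTorch documents for `torch.nn.LSTM`) is assumed rather than taken from the diff.

```python
import math


def matvec(W, x):
    """Return W @ x for a row-major matrix W (list of rows) and a vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def lstm_step_batched(x, h, c, W_ih, W_hh, b_ih, b_hh):
    """One LSTM timestep with two stacked products instead of eight.

    Assumed layout: W_ih is (4*H, I) and W_hh is (4*H, H), with the gate
    blocks stacked row-wise in the order i, f, g, o.
    """
    H = len(h)
    # Two products cover all four gates at once.
    gx = matvec(W_ih, x)
    gh = matvec(W_hh, h)
    pre = [a + b + p + q for a, b, p, q in zip(gx, b_ih, gh, b_hh)]
    # Slice the stacked pre-activations back into the individual gates.
    i = [sigmoid(v) for v in pre[0:H]]
    f = [sigmoid(v) for v in pre[H:2 * H]]
    g = [math.tanh(v) for v in pre[2 * H:3 * H]]
    o = [sigmoid(v) for v in pre[3 * H:4 * H]]
    c_new = [ft * ct + it * gt for ft, ct, it, gt in zip(f, c, i, g)]
    h_new = [ot * math.tanh(ct) for ot, ct in zip(o, c_new)]
    return h_new, c_new


def lstm_step_per_gate(x, h, c, W_ih, W_hh, b_ih, b_hh):
    """Reference step: one pair of products per gate, eight in total."""
    H = len(h)

    def gate(k, act):
        rows = slice(k * H, (k + 1) * H)
        gx = matvec(W_ih[rows], x)
        gh = matvec(W_hh[rows], h)
        return [act(a + b + p + q)
                for a, b, p, q in zip(gx, b_ih[rows], gh, b_hh[rows])]

    i, f = gate(0, sigmoid), gate(1, sigmoid)
    g, o = gate(2, math.tanh), gate(3, sigmoid)
    c_new = [ft * ct + it * gt for ft, ct, it, gt in zip(f, c, i, g)]
    h_new = [ot * math.tanh(ct) for ot, ct in zip(o, c_new)]
    return h_new, c_new
```

Both functions compute the same cell equations. The batched version only changes how many products are issued: the per-gate slices are taken after the two stacked products rather than before.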

@apullin apullin requested a review from digantdesai as a code owner February 3, 2026 07:35
@pytorch-bot

pytorch-bot Bot commented Feb 3, 2026


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17140

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 4 New Failures, 1 Cancelled Job, 3 Unrelated Failures

As of commit 7cd0fbd with merge base 26e2ab8 (image):

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Feb 3, 2026
@meta-codesync

meta-codesync Bot commented Feb 3, 2026

Contributor

@apullin has exported this pull request. If you are a Meta employee, you can view the originating Diff in D92059277.

@github-actions

github-actions Bot commented Feb 3, 2026


This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

apullin pushed a commit to apullin/executorch that referenced this pull request Feb 3, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Feb 3, 2026
@apullin apullin force-pushed the export-D92059277 branch 3 times, most recently from bf1d013 to f57bc19 Compare February 3, 2026 23:23
apullin pushed a commit to apullin/executorch that referenced this pull request Feb 3, 2026
@apullin apullin force-pushed the export-D92059277 branch 3 times, most recently from bc86171 to 62726e7 Compare February 3, 2026 23:55
apullin pushed a commit to apullin/executorch that referenced this pull request Feb 3, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Feb 3, 2026
@zingo zingo added the partner: arm and ciflow/trunk labels Feb 6, 2026
@pytorch-bot

pytorch-bot Bot commented Feb 6, 2026


To add the ciflow label ciflow/trunk please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot Bot removed the ciflow/trunk label Feb 6, 2026
@zingo zingo changed the title Add DecomposeLstmPass for ARM backend Arm backend: Add DecomposeLstmPass Feb 6, 2026
pytorch-bot Bot pushed a commit that referenced this pull request Feb 6, 2026
Comment thread backends/arm/tosa/backend.py Outdated
@gggekov

gggekov commented Feb 6, 2026

Collaborator

What is the reason for decomposing LSTM in the Arm backend rather than letting torch.export.export decompose it?

@gggekov

gggekov commented Feb 6, 2026

Collaborator

Never mind. I see that torch.nn.LSTM is not decomposed by torch.export.export, as I had initially thought.

@apullin

apullin commented Apr 2, 2026

Contributor Author

@gggekov done, updated across all three commits:

  • Removed _add_lstm_workaround and test_decompose_recurrent_tosa_pipelines.py
  • All tests use TosaPipelineFP / TosaPipelineINT (basic, bidirectional, no_bias, multilayer)
  • LSTM FP tests skipped but present, INT is still faithfully tested
  • No feature smear between commits

@apullin apullin left a comment

Contributor Author


Changes related to MLETORCH-1266 have been removed, and MLETORCH-1266 has been settled by another PR.


Comment thread backends/arm/tosa/partitioner.py Outdated
@gggekov

gggekov commented Apr 10, 2026

Collaborator

Thank you @apullin!
I think you are on the final straight. Would you mind fixing the lintrunner issue from the CI?
I think if you run

$ pip install -r requirements-lintrunner.txt
$ lintrunner init
$ lintrunner --revision 'HEAD^'

on the latest commit, you should be able to reproduce the CI failure and fix it (it should be quite an easy fix).
I am also not sure I fully understand the rationale for adding the ops in backends/arm/tosa/partitioner.py; I believe you shouldn't need that.

We are very close to merging that!

Comment thread backends/arm/_passes/decompose_gru_pass.py Outdated
Comment thread backends/arm/_passes/decompose_lstm_pass.py Outdated
Comment thread backends/arm/_passes/decompose_rnn_pass.py Outdated
Comment thread backends/arm/test/passes/test_decompose_rnn_pass.py Outdated
Comment thread backends/arm/test/passes/test_decompose_gru_pass.py Outdated
Comment thread backends/arm/test/passes/test_decompose_lstm_pass.py Outdated

@gggekov gggekov left a comment

Collaborator


Added one more comment about the copyright notice at the top of each new file.


@gggekov gggekov left a comment

Collaborator


Looks good to me to merge

@gggekov

gggekov commented Apr 13, 2026

Collaborator

The merge button is greyed out. I believe a maintainer of ExecuTorch such as @digantdesai needs to approve in order to merge.

@apullin

apullin commented Apr 13, 2026

Contributor Author

@gggekov Ah, sadly, there's one final thing: It looks like there is something wrong with the decomposition where it uses BMM resulting in a failure in ConvertMmToBmmPass, but everything passes public CI and only fails on the internal Meta test stack. Really not sure why that is.

The workaround I tried & pushed is to replace mm(x, transpose(W)) with linear(x, W, bias) ... a little unfortunate since I know people love BMM. But that gets this passing without needing to add another commit for BMM.
Changes are located entirely in _build_direction.

Commits are updated with the changes that pass internally.
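
For readers following the mm-to-linear swap, here is a minimal plain-Python sketch of why the two forms are numerically interchangeable. It only demonstrates the identity `linear(x, W, bias) == mm(x, transpose(W)) + bias`. It does not reproduce the graph code in `_build_direction` or explain the ConvertMmToBmmPass failure, and the helper names below are hypothetical stand-ins for the corresponding aten ops.

```python
def transpose(M):
    """Return the transpose of a row-major matrix."""
    return [list(col) for col in zip(*M)]


def mm(A, B):
    """Plain matrix product A @ B."""
    Bt = transpose(B)
    return [[sum(a * b for a, b in zip(row, col)) for col in Bt] for row in A]


def linear(x, W, bias=None):
    """y = x @ W.T + bias, with W stored as (out_features, in_features)."""
    out = [[sum(a * w for a, w in zip(row, w_row)) for w_row in W] for row in x]
    if bias is not None:
        out = [[v + b for v, b in zip(row, bias)] for row in out]
    return out


x = [[1.0, 2.0], [3.0, 4.0]]                  # (batch=2, in=2)
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]      # (out=3, in=2)
bias = [0.5, -0.5, 0.0]

# Original form: explicit transpose, matrix product, then a separate add.
via_mm = [[v + b for v, b in zip(row, bias)] for row in mm(x, transpose(W))]
# Workaround form: a single linear call that takes W untransposed.
via_linear = linear(x, W, bias)
```

The linear form also folds the bias add into the same op, so a decomposition written this way emits no separate transpose or add for the input projection.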

@zingo

zingo commented Apr 14, 2026

Collaborator

> @gggekov Ah, sadly, there's one final thing: It looks like there is something wrong with the decomposition where it uses BMM resulting in a failure in ConvertMmToBmmPass, but everything passes public CI and only fails on the internal Meta test stack. Really not sure why that is.
>
> The workaround I tried & pushed is to replace mm(x, transpose(W)) with linear(x, W, bias) ... a little unfortunate since I know people love BMM. But that gets this passing without needing to add another commit for BMM. Changes are located entirely in _build_direction.
>
> Commits are updated with the changes that pass internally.

Thanks for making sure; that is really good. Would it be possible/OK to get some version of that test into the tests on GitHub?


@gggekov

gggekov commented Apr 14, 2026

Collaborator

Restarting the CI as there was one failure for our vkml test cases.

@gggekov

gggekov commented Apr 14, 2026

Collaborator

Think you need a stamp from @digantdesai and can then merge

@digantdesai

Contributor

Rebase?

Andrew Pullin added 3 commits April 14, 2026 16:08
Summary:

Adds a decomposition pass that transforms aten.gru.input into elementary
ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat).

GRU cell equations per timestep:
    r_t = sigmoid(x_t @ W_ir.T + b_ir + h_{t-1} @ W_hr.T + b_hr)
    z_t = sigmoid(x_t @ W_iz.T + b_iz + h_{t-1} @ W_hz.T + b_hz)
    n_t = tanh(x_t @ W_in.T + b_in + r_t * (h_{t-1} @ W_hn.T + b_hn))
    h_t = n_t + z_t * (h_{t-1} - n_t)

Features:
- Multi-layer GRU support
- Bidirectional GRU support
- With/without bias
- batch_first support
- Batched gate computation (2 mm ops per timestep instead of 6)

Differential Revision: D92058313
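
As a rough illustration of the GRU equations above, here is a minimal plain-Python sketch of one timestep using two stacked products. It is not the pass's code. The name `gru_step` and the stacked layout (gate blocks in the order r, z, n, as PyTorch documents for `torch.nn.GRU`) are assumptions.

```python
import math


def matvec(W, x):
    """Return W @ x for a row-major matrix W (list of rows) and a vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def gru_step(x, h, W_ih, W_hh, b_ih, b_hh):
    """One GRU timestep with two stacked products instead of six.

    Assumed layout: W_ih is (3*H, I) and W_hh is (3*H, H), with the gate
    blocks stacked row-wise in the order r, z, n.
    """
    H = len(h)
    # Input and hidden projections stay separate because the n gate
    # scales only the hidden part by r_t.
    gi = [a + b for a, b in zip(matvec(W_ih, x), b_ih)]
    gh = [a + b for a, b in zip(matvec(W_hh, h), b_hh)]
    r = [sigmoid(a + b) for a, b in zip(gi[0:H], gh[0:H])]
    z = [sigmoid(a + b) for a, b in zip(gi[H:2 * H], gh[H:2 * H])]
    n = [math.tanh(a + rt * b)
         for a, rt, b in zip(gi[2 * H:3 * H], r, gh[2 * H:3 * H])]
    # h_t = n_t + z_t * (h_{t-1} - n_t)
    return [nt + zt * (ht - nt) for nt, zt, ht in zip(n, z, h)]
```

Unlike the LSTM case, the two projections cannot be summed before slicing, since `r_t` multiplies only the hidden contribution of the n gate.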
Summary:

Adds a decomposition pass that transforms aten.rnn_tanh.input and
aten.rnn_relu.input into elementary ops supported by TOSA.

RNN cell equation per timestep:
    h_t = activation(x_t @ W_ih.T + b_ih + h_{t-1} @ W_hh.T + b_hh)

where activation is tanh (rnn_tanh) or relu (rnn_relu).

Features:
- Multi-layer RNN support
- Bidirectional RNN support
- With/without bias
- batch_first support
- Both tanh and relu nonlinearities

Differential Revision: D92059152
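
The RNN equation, together with the bidirectional feature, can be sketched as an unrolled loop in plain Python. This is only an illustration under assumed conventions: a single layer, per-timestep inputs given as a list, and backward outputs stored in original time order as `torch.nn.RNN` does. The names `rnn_direction` and `rnn_bidirectional` are hypothetical and do not come from the PR.

```python
import math


def matvec(W, x):
    """Return W @ x for a row-major matrix W (list of rows) and a vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def relu(v):
    return max(0.0, v)


def rnn_direction(xs, h0, W_ih, W_hh, b_ih, b_hh, act=math.tanh, reverse=False):
    """Unroll one direction of a single-layer RNN over the timesteps in xs.

    Returns (outputs, h_final); outputs are always in original time order.
    """
    order = range(len(xs) - 1, -1, -1) if reverse else range(len(xs))
    h = list(h0)
    outputs = [None] * len(xs)
    for t in order:
        pre = [a + b + p + q
               for a, b, p, q in zip(matvec(W_ih, xs[t]), b_ih,
                                     matvec(W_hh, h), b_hh)]
        h = [act(v) for v in pre]
        outputs[t] = h
    return outputs, h


def rnn_bidirectional(xs, h0_fwd, h0_bwd, fwd_params, bwd_params, act=math.tanh):
    """Run both directions and concatenate their outputs per timestep."""
    out_f, h_f = rnn_direction(xs, h0_fwd, *fwd_params, act=act)
    out_b, h_b = rnn_direction(xs, h0_bwd, *bwd_params, act=act, reverse=True)
    return [f + b for f, b in zip(out_f, out_b)], (h_f, h_b)


# Toy example: 1 input feature, 1 hidden unit, identity weights, relu.
params = ([[1.0]], [[1.0]], [0.0], [0.0])
xs = [[1.0], [2.0], [3.0]]
out, (h_f, h_b) = rnn_bidirectional(xs, [0.0], [0.0], params, params, act=relu)
```

With these toy weights each direction is a running sum, so `out` pairs the forward prefix sum with the backward suffix sum at each timestep.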
Summary:

Adds a decomposition pass that transforms aten.lstm.input into elementary
ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat).

LSTM cell equations per timestep:
    i_t = sigmoid(x_t @ W_ii.T + b_ii + h_{t-1} @ W_hi.T + b_hi)
    f_t = sigmoid(x_t @ W_if.T + b_if + h_{t-1} @ W_hf.T + b_hf)
    g_t = tanh(x_t @ W_ig.T + b_ig + h_{t-1} @ W_hg.T + b_hg)
    o_t = sigmoid(x_t @ W_io.T + b_io + h_{t-1} @ W_ho.T + b_ho)
    c_t = f_t * c_{t-1} + i_t * g_t
    h_t = o_t * tanh(c_t)

Features:
- Multi-layer LSTM support
- Bidirectional LSTM support
- With/without bias
- batch_first support
- Batched gate computation (2 mm ops per timestep instead of 8)

Differential Revision: D92059277
@gggekov

gggekov commented Apr 15, 2026

Collaborator

Thanks a lot for the contribution, @apullin!


Labels

  • ciflow/trunk
  • CLA Signed (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.)
  • fb-exported
  • meta-exported
  • partner: arm (For backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm)

6 participants